{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chapter 4: Transfer Learning And Other Tricks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.utils.data\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "import torchvision.models as models\n",
    "from torchvision import transforms\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a ResNet-50 with pretrained weights as the starting point for transfer learning.\n",
    "transfer_model = models.resnet50(pretrained=True) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Freezing parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freeze every parameter except those whose name contains \"bn\"\n",
    "# (presumably the batch-norm layers -- confirm against the model's parameter names).\n",
    "for param_name, parameter in transfer_model.named_parameters():\n",
    "    if \"bn\" in param_name:\n",
    "        continue\n",
    "    parameter.requires_grad = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Replacing the classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Swap the original classifier for a small two-layer head with 2 outputs.\n",
    "transfer_model.fc = nn.Sequential(\n",
    "    nn.Linear(transfer_model.fc.in_features, 500),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(),\n",
    "    nn.Linear(500, 2),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Again"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device=\"cpu\"):\n",
    "    \"\"\"Train ``model`` and print per-epoch training loss, validation loss and accuracy.\n",
    "\n",
    "    Args:\n",
    "        model: network to optimise. Only the batches are moved to ``device`` here,\n",
    "            so the caller must place the model on that device beforehand.\n",
    "        optimizer: optimizer that steps the model's parameters.\n",
    "        loss_fn: callable ``loss_fn(output, targets)`` returning a scalar loss tensor.\n",
    "        train_loader: iterable of ``(inputs, targets)`` batches exposing ``.dataset``.\n",
    "        val_loader: same contract as ``train_loader``; used for evaluation only.\n",
    "        epochs: number of passes over ``train_loader``.\n",
    "        device: device the batches are moved to before the forward pass.\n",
    "\n",
    "    Returns:\n",
    "        None. One summary line is printed per epoch.\n",
    "    \"\"\"\n",
    "    for epoch in range(epochs):\n",
    "        training_loss = 0.0\n",
    "        valid_loss = 0.0\n",
    "\n",
    "        model.train()\n",
    "        for inputs, targets in train_loader:\n",
    "            optimizer.zero_grad()\n",
    "            inputs = inputs.to(device)\n",
    "            targets = targets.to(device)\n",
    "            output = model(inputs)\n",
    "            loss = loss_fn(output, targets)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # Weight by batch size so the epoch figure is a per-sample mean.\n",
    "            training_loss += loss.item() * inputs.size(0)\n",
    "        training_loss /= len(train_loader.dataset)\n",
    "\n",
    "        model.eval()\n",
    "        num_correct = 0\n",
    "        num_examples = 0\n",
    "        # Evaluation needs no gradients; disabling autograd saves memory and time.\n",
    "        with torch.no_grad():\n",
    "            for inputs, targets in val_loader:\n",
    "                inputs = inputs.to(device)\n",
    "                targets = targets.to(device)\n",
    "                output = model(inputs)\n",
    "                loss = loss_fn(output, targets)\n",
    "                valid_loss += loss.item() * inputs.size(0)\n",
    "                # Softmax is monotonic, so the arg-max of the raw outputs over the\n",
    "                # class dimension already gives the predicted class.\n",
    "                predicted = output.argmax(dim=1)\n",
    "                correct = torch.eq(predicted, targets).view(-1)\n",
    "                num_correct += torch.sum(correct).item()\n",
    "                num_examples += correct.shape[0]\n",
    "        valid_loss /= len(val_loader.dataset)\n",
    "\n",
    "        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(\n",
    "            epoch, training_loss, valid_loss, num_correct / num_examples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_image(path):\n",
    "    \"\"\"Return True when PIL can open ``path`` as an image, False otherwise.\n",
    "\n",
    "    Intended as an ``is_valid_file`` predicate for ``ImageFolder``.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        # The context manager closes the file handle that Image.open leaves open.\n",
    "        with Image.open(path):\n",
    "            return True\n",
    "    except Exception:\n",
    "        # Unreadable or non-image files are reported as invalid instead of raising;\n",
    "        # KeyboardInterrupt / SystemExit are deliberately not swallowed.\n",
    "        return False\n",
    "# Preprocessing shared by both splits: resize to 64x64, convert to a tensor,\n",
    "# then normalise each channel with the given mean / std.\n",
    "img_transforms = transforms.Compose(\n",
    "    [\n",
    "        transforms.Resize((64, 64)),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Both datasets use the same transforms and filter files through check_image.\n",
    "train_data_path = \"./train/\"\n",
    "train_data = torchvision.datasets.ImageFolder(\n",
    "    root=train_data_path, transform=img_transforms, is_valid_file=check_image\n",
    ")\n",
    "val_data_path = \"./val/\"\n",
    "val_data = torchvision.datasets.ImageFolder(\n",
    "    root=val_data_path, transform=img_transforms, is_valid_file=check_image\n",
    ")\n",
    "\n",
    "batch_size = 64\n",
    "train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)\n",
    "val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of samples in the dataset behind the validation loader.\n",
    "print(len(val_data_loader.dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the model to the chosen device, then build the optimizer over its parameters.\n",
    "transfer_model = transfer_model.to(device)\n",
    "optimizer = optim.Adam(params=transfer_model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tune the model for 5 epochs with cross-entropy loss.\n",
    "train(\n",
    "    transfer_model,\n",
    "    optimizer,\n",
    "    torch.nn.CrossEntropyLoss(),\n",
    "    train_data_loader,\n",
    "    val_data_loader,\n",
    "    epochs=5,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LR Finder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_lr(model, loss_fn, optimizer, train_loader, init_value=1e-8, final_value=10.0, device=\"cpu\"):\n",
    "    \"\"\"Sweep the learning rate over one pass of ``train_loader`` and record the loss.\n",
    "\n",
    "    The learning rate of the optimizer's first parameter group starts at\n",
    "    ``init_value`` and is multiplied by a constant factor after every batch so that\n",
    "    it reaches ``final_value`` on the last batch. The sweep stops early once a batch\n",
    "    loss exceeds four times the best loss seen so far.\n",
    "\n",
    "    Note that this really trains: ``optimizer.step()`` is called on every batch, so\n",
    "    the model weights and the optimizer's learning rate are modified.\n",
    "\n",
    "    Args:\n",
    "        model: network to probe; batches (not the model) are moved to ``device``.\n",
    "        loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar loss tensor.\n",
    "        optimizer: optimizer whose first parameter group's ``lr`` is swept.\n",
    "        train_loader: sized iterable of ``(inputs, targets)`` batches.\n",
    "        init_value: first learning rate tried.\n",
    "        final_value: learning rate targeted for the last batch.\n",
    "        device: device the batches are moved to.\n",
    "\n",
    "    Returns:\n",
    "        ``(lrs, losses)``: the learning rates tried (raw values, not logarithms) and\n",
    "        the matching losses as floats. When more than 20 points were recorded, the\n",
    "        first 10 and last 5 are dropped from both lists.\n",
    "    \"\"\"\n",
    "    # Use at least one step so a single-batch loader does not divide by zero.\n",
    "    number_of_steps = max(len(train_loader) - 1, 1)\n",
    "    update_step = (final_value / init_value) ** (1 / number_of_steps)\n",
    "    lr = init_value\n",
    "    optimizer.param_groups[0][\"lr\"] = lr\n",
    "    best_loss = 0.0\n",
    "    losses = []\n",
    "    lrs = []\n",
    "    for batch_num, (inputs, targets) in enumerate(train_loader, start=1):\n",
    "        inputs = inputs.to(device)\n",
    "        targets = targets.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(inputs)\n",
    "        loss = loss_fn(outputs, targets)\n",
    "        # Work with a plain float so no autograd graph is kept alive between batches.\n",
    "        loss_value = loss.item()\n",
    "\n",
    "        # Stop the sweep once the loss explodes\n",
    "        if batch_num > 1 and loss_value > 4 * best_loss:\n",
    "            break\n",
    "\n",
    "        # Record the best loss\n",
    "        if batch_num == 1 or loss_value < best_loss:\n",
    "            best_loss = loss_value\n",
    "\n",
    "        # Store the values\n",
    "        losses.append(loss_value)\n",
    "        lrs.append(lr)\n",
    "\n",
    "        # Do the backward pass and optimize\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Update the lr for the next step\n",
    "        lr *= update_step\n",
    "        optimizer.param_groups[0][\"lr\"] = lr\n",
    "\n",
    "    # Trim the noisy start and the diverging tail when there are enough points.\n",
    "    if len(lrs) > 20:\n",
    "        return lrs[10:-5], losses[10:-5]\n",
    "    return lrs, losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep the learning rate, then plot loss against it on a log-scaled x axis.\n",
    "lrs, losses = find_lr(\n",
    "    transfer_model,\n",
    "    torch.nn.CrossEntropyLoss(),\n",
    "    optimizer,\n",
    "    train_data_loader,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(lrs, losses)\n",
    "ax.set_xscale(\"log\")\n",
    "ax.set_xlabel(\"Learning rate\")\n",
    "ax.set_ylabel(\"Loss\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Custom Transforms\n",
    "\n",
    "Here we'll create a lambda transform and a custom transform class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _random_colour_space(x):\n",
    "    \"\"\"Return ``x`` converted to the HSV colour space via ``x.convert(\"HSV\")\"\"\"\n",
    "    return x.convert(\"HSV\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the colour-space helper so it can be used inside a transforms pipeline.\n",
    "colour_transform = transforms.Lambda(_random_colour_space)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply the colour transform only some of the time (RandomApply's default probability).\n",
    "random_colour_transform = transforms.RandomApply([colour_transform])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Noise():\n",
    "    \"\"\"Adds gaussian noise to a tensor.\n",
    "\n",
    "    Example:\n",
    "        >>> transforms.Compose([\n",
    "        >>>     transforms.ToTensor(),\n",
    "        >>>     Noise(0.1, 0.05),\n",
    "        >>> ])\n",
    "\n",
    "    Args:\n",
    "        mean: mean of the gaussian noise.\n",
    "        stddev: standard deviation of the gaussian noise.\n",
    "    \"\"\"\n",
    "    def __init__(self, mean, stddev):\n",
    "        self.mean = mean\n",
    "        self.stddev = stddev\n",
    "\n",
    "    def __call__(self, tensor):\n",
    "        \"\"\"Return a new tensor equal to ``tensor`` plus freshly sampled noise.\"\"\"\n",
    "        noise = torch.zeros_like(tensor).normal_(self.mean, self.stddev)\n",
    "        # Out-of-place add: the caller's tensor is left untouched.\n",
    "        return tensor + noise\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\"{self.__class__.__name__}(mean={self.mean},stddev={self.stddev})\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Noise operates on tensors, so the (PIL) image must be converted with ToTensor first.\n",
    "custom_transform_pipeline = transforms.Compose([\n",
    "    random_colour_transform,\n",
    "    transforms.ToTensor(),\n",
    "    Noise(0.1, 0.05),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ensembles\n",
    "\n",
    "Given a list of models, we can produce a prediction from each model and then average them to make a final prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two ResNet-50s built without pretrained weights stand in for separately trained models.\n",
    "# eval() puts batch-norm layers into inference mode.\n",
    "models_ensemble = [models.resnet50().to(device).eval(), models.resnet50().to(device).eval()]\n",
    "\n",
    "# Every ensemble member must score the SAME input for the average to be meaningful.\n",
    "sample_batch = torch.rand(1, 3, 224, 224).to(device)\n",
    "\n",
    "# Inference only: no gradients needed. Softmax is taken over the class dimension.\n",
    "with torch.no_grad():\n",
    "    predictions = [F.softmax(m(sample_batch), dim=1) for m in models_ensemble]\n",
    "avg_prediction = torch.stack(predictions).mean(0).argmax()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the arg-max index of the averaged ensemble prediction.\n",
    "avg_prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the per-model predictions stacked into a single tensor.\n",
    "torch.stack(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
